In [5]:
import tensorflow_datasets as tfds
import tensorflow as tf

# Load tf_flowers (5 flower classes) and carve its single 'train' split into
# 80% train / 10% validation / 10% test.
# as_supervised=True yields (image, label) pairs; with_info=True also returns
# dataset metadata (class names, sizes, ...).
(raw_train, raw_validation, raw_test), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)
In [6]:
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
In [7]:
# Display the first 10 raw images with their integer labels and class names.
get_label_name = metadata.features['label'].int2str

plt.figure(figsize=(10, 5))
for plot_idx, (image, label) in enumerate(raw_train.take(10), start=1):
    plt.subplot(2, 5, plot_idx)
    plt.imshow(image)
    plt.title(f'label {label}: {get_label_name(label)}')
    plt.axis('off')
In [8]:
IMG_SIZE = 160  # every image is resized to IMG_SIZE x IMG_SIZE

def format_example(image, label):
    """Cast pixels to float32, scale them into [-1, 1], and resize to a fixed size."""
    image = tf.cast(image, tf.float32)  # TensorFlow's equivalent of float(image)-style casting
    image = image / 127.5 - 1
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, label
In [9]:
# Apply the same preprocessing to all three splits.
train, validation, test = (
    split.map(format_example)
    for split in (raw_train, raw_validation, raw_test)
)
In [31]:
# Preview the preprocessed images; undo the [-1, 1] scaling for display.
get_label_name = metadata.features['label'].int2str

plt.figure(figsize=(10, 5))
for plot_idx, (image, label) in enumerate(train.take(10), start=1):
    plt.subplot(2, 5, plot_idx)
    plt.imshow((image + 1) / 2)
    plt.title(f'label {label}: {get_label_name(label)}')
    plt.axis('off')
In [32]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, MaxPooling2D
In [42]:
# Small CNN trained from scratch: three conv/pool stages, then a dense
# classifier over the 5 flower classes.
model = Sequential()
model.add(Conv2D(filters=16, kernel_size=3, padding='same', activation='relu',
                 input_shape=(160, 160, 3)))
model.add(MaxPooling2D())
model.add(Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'))
model.add(MaxPooling2D())
model.add(Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'))
model.add(MaxPooling2D())
model.add(Flatten())
model.add(Dense(units=512, activation='relu'))
model.add(Dense(units=5, activation='softmax'))
In [43]:
import numpy as np

# A tiny 2x2 example to illustrate what Flatten() does to a tensor.
image = np.array([[1, 2],
                  [3, 4]])
print(image.shape)
image
(2, 2)
Out[43]:
array([[1, 2],
       [3, 4]])
In [44]:
image.flatten()
Out[44]:
array([1, 2, 3, 4])
In [45]:
# Compile the baseline CNN.  Labels are integer class ids, so pair the
# softmax output with sparse categorical cross-entropy.
# NOTE: the `lr` keyword is deprecated (and removed in newer TF2 releases);
# the supported name is `learning_rate`.
base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])
In [46]:
BATCH_SIZE = 32  # samples per batch
SHUFFLE_BUFFER_SIZE = 1000  # shuffle buffer size; larger means better mixing
In [47]:
# Shuffle only the training data; evaluation order does not matter.
train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)
In [48]:
# Grab one training batch to confirm its shape: (batch, height, width, channels).
image_batch, label_batch = next(iter(train_batches.take(1)))

image_batch.shape
Out[48]:
TensorShape([32, 160, 160, 3])
In [49]:
validation_steps = 3
# Metrics before any training; accuracy should be near chance (1/5 = 0.2).
loss0, accuracy0 = model.evaluate(validation_batches, steps=validation_steps)

print("initial loss: {:.4f}".format(loss0))
print("initial accuracy: {:.4f}".format(accuracy0))
3/3 [==============================] - 0s 9ms/step - loss: 1.6261 - accuracy: 0.1771
initial loss: 1.6261
initial accuracy: 0.1771
In [50]:
EPOCHS = 10
# Train the baseline CNN; `history` records per-epoch loss and accuracy.
history = model.fit(train_batches,
                    epochs=EPOCHS,
                    validation_data=validation_batches)
Epoch 1/10
92/92 [==============================] - 3s 28ms/step - loss: 1.3358 - accuracy: 0.4261 - val_loss: 1.0922 - val_accuracy: 0.5504
Epoch 2/10
92/92 [==============================] - 2s 26ms/step - loss: 1.0502 - accuracy: 0.5743 - val_loss: 1.0961 - val_accuracy: 0.5177
Epoch 3/10
92/92 [==============================] - 2s 27ms/step - loss: 0.9292 - accuracy: 0.6301 - val_loss: 1.0714 - val_accuracy: 0.5531
Epoch 4/10
92/92 [==============================] - 3s 29ms/step - loss: 0.8285 - accuracy: 0.7016 - val_loss: 1.0379 - val_accuracy: 0.5749
Epoch 5/10
92/92 [==============================] - 3s 29ms/step - loss: 0.7403 - accuracy: 0.7183 - val_loss: 1.0196 - val_accuracy: 0.6185
Epoch 6/10
92/92 [==============================] - 2s 26ms/step - loss: 0.6419 - accuracy: 0.7694 - val_loss: 0.8836 - val_accuracy: 0.6649
Epoch 7/10
92/92 [==============================] - 2s 26ms/step - loss: 0.5645 - accuracy: 0.8099 - val_loss: 1.0081 - val_accuracy: 0.5913
Epoch 8/10
92/92 [==============================] - 2s 26ms/step - loss: 0.4723 - accuracy: 0.8396 - val_loss: 0.9511 - val_accuracy: 0.6294
Epoch 9/10
92/92 [==============================] - 2s 26ms/step - loss: 0.4131 - accuracy: 0.8672 - val_loss: 1.0191 - val_accuracy: 0.6240
Epoch 10/10
92/92 [==============================] - 2s 26ms/step - loss: 0.3537 - accuracy: 0.8931 - val_loss: 1.1095 - val_accuracy: 0.6076
In [51]:
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(EPOCHS)

# Plot accuracy and loss side by side.  The original cell had the second
# subplot commented out, so the loss curves were drawn on top of the
# accuracy axes and overwrote its title/legend; a 1x2 layout (as used in
# the transfer-learning plot below) fixes that.
plt.figure(figsize=(8, 4))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
In [52]:
# Run the baseline model on one test batch; keep the images and labels
# around for the visualization cells below.
image_batch, label_batch = next(iter(test_batches.take(1)))
images = image_batch
labels = label_batch
predictions = model.predict(image_batch)

predictions
Out[52]:
array([[2.52832135e-04, 5.38016087e-04, 9.63347971e-01, 8.10046773e-03,
        2.77606100e-02],
       [6.40217543e-01, 3.22125465e-01, 3.45303342e-02, 9.85061124e-05,
        3.02810688e-03],
       [1.90647259e-01, 5.20059206e-02, 5.38697779e-01, 2.16933101e-01,
        1.71590818e-03],
       [8.98692384e-03, 2.87711471e-01, 6.23172283e-01, 3.64238885e-03,
        7.64869601e-02],
       [5.78166068e-01, 1.92818522e-01, 9.05924216e-02, 1.72735881e-02,
        1.21149383e-01],
       [2.96536535e-01, 3.65053147e-01, 1.61174595e-01, 1.32185966e-02,
        1.64017186e-01],
       [1.94080480e-04, 2.46301908e-02, 9.67658341e-01, 5.15324820e-04,
        7.00211665e-03],
       [1.15798436e-01, 9.69402194e-02, 3.46835762e-01, 2.86086295e-02,
        4.11816925e-01],
       [1.59341877e-03, 1.05547355e-04, 2.96385922e-02, 9.68611538e-01,
        5.09743295e-05],
       [9.48890626e-01, 3.88295352e-02, 5.09355357e-03, 8.75667465e-05,
        7.09877769e-03],
       [9.46866453e-01, 1.85349211e-02, 2.51511615e-02, 3.82965198e-04,
        9.06451326e-03],
       [5.84821440e-02, 3.50442007e-02, 3.84957850e-01, 4.66118395e-01,
        5.53973950e-02],
       [1.03033509e-03, 2.01499108e-02, 3.96573275e-01, 4.90770221e-01,
        9.14762542e-02],
       [1.62184727e-03, 3.22985649e-02, 7.31879473e-01, 1.00478977e-02,
        2.24152237e-01],
       [1.08076729e-01, 5.87864295e-02, 3.21560413e-01, 5.68003161e-03,
        5.05896389e-01],
       [7.21855927e-03, 7.11735862e-04, 2.25104839e-01, 7.66285419e-01,
        6.79439108e-04],
       [4.16492230e-06, 1.98544760e-04, 9.88362968e-01, 5.63554022e-05,
        1.13781076e-02],
       [2.64011383e-01, 5.58990194e-03, 3.70046496e-02, 6.91495836e-01,
        1.89829932e-03],
       [2.01639742e-01, 3.45539898e-01, 2.37432063e-01, 4.75826953e-03,
        2.10630044e-01],
       [4.47703613e-04, 2.36064233e-02, 8.76172125e-01, 2.75400198e-05,
        9.97461751e-02],
       [2.51203106e-04, 1.09957480e-04, 9.78543818e-01, 8.21511075e-03,
        1.28799248e-02],
       [3.81780326e-01, 5.39589345e-01, 3.92865874e-02, 2.61724107e-02,
        1.31713506e-02],
       [1.43412547e-03, 4.15829182e-01, 5.58871090e-01, 3.08487518e-03,
        2.07807235e-02],
       [2.80288190e-01, 5.92788339e-01, 8.97711515e-02, 5.39092254e-03,
        3.17613818e-02],
       [2.49150414e-02, 7.09640145e-01, 2.33453646e-01, 1.23032500e-04,
        3.18681374e-02],
       [5.26005216e-02, 3.33530232e-02, 8.01757097e-01, 1.11194968e-01,
        1.09444151e-03],
       [2.41643772e-03, 1.34428721e-02, 8.01131368e-01, 9.49926743e-06,
        1.82999939e-01],
       [3.38705555e-02, 3.68147880e-01, 3.32776874e-01, 3.50991786e-02,
        2.30105504e-01],
       [4.07020252e-06, 5.46927156e-04, 9.84821141e-01, 1.30373865e-05,
        1.46149034e-02],
       [5.41259497e-02, 5.59208333e-01, 1.60649255e-01, 2.22648934e-01,
        3.36745824e-03],
       [6.28494143e-01, 2.97198296e-01, 6.73923120e-02, 1.34501734e-03,
        5.57023380e-03],
       [3.82469267e-01, 1.15222648e-01, 4.21848387e-01, 7.33610184e-04,
        7.97260702e-02]], dtype=float32)
In [53]:
import numpy as np

# Reduce each row of class probabilities to the predicted class id.
predictions = predictions.argmax(axis=1)
predictions
Out[53]:
array([2, 0, 2, 2, 0, 1, 2, 4, 3, 0, 0, 3, 3, 2, 4, 3, 2, 3, 1, 2, 2, 1,
       2, 1, 1, 2, 2, 1, 2, 1, 0, 2])
In [54]:
# Visualize the test batch: blue titles mark correct predictions, red wrong ones.
plt.figure(figsize=(20, 12))

for plot_idx, (image, label, prediction) in enumerate(zip(images, labels, predictions), start=1):
    plt.subplot(4, 8, plot_idx)
    plt.imshow((image + 1) / 2)
    correct = label == prediction
    title = f'real: {label} / pred :{prediction}\n {correct}!'
    color = 'blue' if correct else 'red'
    plt.title(title, fontdict={'color': color})
    plt.axis('off')
In [55]:
# Accuracy on this test batch as a percentage.  The original loop
# re-normalized each image without using it (dead code) and hard-coded the
# batch size (32); derive the denominator from the predictions instead.
count = sum(1 for label, prediction in zip(labels, predictions)
            if label == prediction)  # number of correct predictions

print(count / len(predictions) * 100, '%')
43.75 %
In [56]:
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)

# Create the base model from the pre-trained model VGG16.
# include_top=False drops VGG16's ImageNet classifier head so a custom
# 5-class head can be attached; weights come from ImageNet pre-training.
base_model = tf.keras.applications.VGG16(input_shape=IMG_SHAPE,
                                         include_top=False,
                                         weights='imagenet')
In [57]:
image_batch.shape
Out[57]:
TensorShape([32, 160, 160, 3])
In [58]:
# Pass the batch through the convolutional base; each image becomes a
# 5x5x512 feature map.
feature_batch = base_model(image_batch)
feature_batch.shape
Out[58]:
TensorShape([32, 5, 5, 512])
In [59]:
base_model.summary()
Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 160, 160, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 160, 160, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 80, 80, 64)        0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 80, 80, 128)       73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 80, 80, 128)       147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 40, 40, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 40, 40, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 40, 40, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 40, 40, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 20, 20, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 20, 20, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 20, 20, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 20, 20, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 10, 10, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 10, 10, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 10, 10, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 10, 10, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 5, 5, 512)         0         
=================================================================
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0
_________________________________________________________________
In [60]:
feature_batch.shape
Out[60]:
TensorShape([32, 5, 5, 512])
In [61]:
import numpy as np

# Contrast with GlobalAveragePooling below: flatten() keeps every element,
# turning the 2x2 array into a length-4 vector.
image = np.array([[1, 2], [3, 4]])
flattened_image = image.flatten()

print("Original image:\n", image)
print("Original image shape:", image.shape)
print()
print("Flattened image:\n", flattened_image)
print("Flattened image shape:", flattened_image.shape)
Original image:
 [[1 2]
 [3 4]]
Original image shape: (2, 2)

Flattened image:
 [1 2 3 4]
Flattened image shape: (4,)
In [62]:
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
In [63]:
# Apply the pooling layer to the extracted features and confirm the shape.
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)
(32, 512)
In [82]:
dense_layer = tf.keras.layers.Dense(512, activation='relu')
prediction_layer = tf.keras.layers.Dense(5, activation='softmax')

# Passing feature_batch_average through dense_layer and then prediction_layer
# yields per-image probabilities over the 5 classes.
prediction_batch = prediction_layer(dense_layer(feature_batch_average))  
print(prediction_batch.shape)
(32, 5)
In [83]:
base_model.trainable = False
In [84]:
# Transfer-learning model: frozen VGG16 base + global pooling + new classifier head.
model = tf.keras.Sequential()
model.add(base_model)
model.add(global_average_layer)
model.add(dense_layer)
model.add(prediction_layer)
In [85]:
model.summary()
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
vgg16 (Model)                (None, 5, 5, 512)         14714688  
_________________________________________________________________
global_average_pooling2d (Gl (None, 512)               0         
_________________________________________________________________
dense_8 (Dense)              (None, 512)               262656    
_________________________________________________________________
dense_9 (Dense)              (None, 5)                 2565      
=================================================================
Total params: 14,979,909
Trainable params: 265,221
Non-trainable params: 14,714,688
_________________________________________________________________
In [86]:
# Compile the transfer-learning model with the same loss/metric setup as the
# baseline.  NOTE: the `lr` keyword is deprecated (and removed in newer TF2
# releases); the supported name is `learning_rate`.
base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])
In [88]:
validation_steps=10
# Metrics before training the head; expect near-chance accuracy (~0.2).
loss0, accuracy0 = model.evaluate(validation_batches, steps = validation_steps)

print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))
10/10 [==============================] - 1s 53ms/step - loss: 1.6564 - accuracy: 0.1937
initial loss: 1.66
initial accuracy: 0.19
In [111]:
EPOCHS = 20   # transfer learning converges much faster than training from scratch

history = model.fit(train_batches,
                    epochs=EPOCHS,
                    validation_data=validation_batches)
Epoch 1/20
92/92 [==============================] - 7s 75ms/step - loss: 0.4535 - accuracy: 0.8522 - val_loss: 0.5154 - val_accuracy: 0.8174
Epoch 2/20
92/92 [==============================] - 7s 73ms/step - loss: 0.4421 - accuracy: 0.8559 - val_loss: 0.5081 - val_accuracy: 0.8338
Epoch 3/20
92/92 [==============================] - 7s 74ms/step - loss: 0.4329 - accuracy: 0.8597 - val_loss: 0.5070 - val_accuracy: 0.8229
Epoch 4/20
92/92 [==============================] - 7s 74ms/step - loss: 0.4242 - accuracy: 0.8651 - val_loss: 0.5075 - val_accuracy: 0.8229
Epoch 5/20
92/92 [==============================] - 7s 74ms/step - loss: 0.4165 - accuracy: 0.8658 - val_loss: 0.4915 - val_accuracy: 0.8229
Epoch 6/20
92/92 [==============================] - 7s 75ms/step - loss: 0.4077 - accuracy: 0.8665 - val_loss: 0.4878 - val_accuracy: 0.8311
Epoch 7/20
92/92 [==============================] - 7s 74ms/step - loss: 0.4014 - accuracy: 0.8699 - val_loss: 0.4857 - val_accuracy: 0.8338
Epoch 8/20
92/92 [==============================] - 7s 74ms/step - loss: 0.3957 - accuracy: 0.8757 - val_loss: 0.4715 - val_accuracy: 0.8365
Epoch 9/20
92/92 [==============================] - 7s 75ms/step - loss: 0.3865 - accuracy: 0.8730 - val_loss: 0.4793 - val_accuracy: 0.8229
Epoch 10/20
92/92 [==============================] - 7s 74ms/step - loss: 0.3825 - accuracy: 0.8787 - val_loss: 0.4700 - val_accuracy: 0.8392
Epoch 11/20
92/92 [==============================] - 7s 74ms/step - loss: 0.3756 - accuracy: 0.8784 - val_loss: 0.4812 - val_accuracy: 0.8311
Epoch 12/20
92/92 [==============================] - 7s 75ms/step - loss: 0.3703 - accuracy: 0.8798 - val_loss: 0.4631 - val_accuracy: 0.8365
Epoch 13/20
92/92 [==============================] - 7s 74ms/step - loss: 0.3640 - accuracy: 0.8856 - val_loss: 0.4651 - val_accuracy: 0.8338
Epoch 14/20
92/92 [==============================] - 7s 75ms/step - loss: 0.3591 - accuracy: 0.8828 - val_loss: 0.4597 - val_accuracy: 0.8365
Epoch 15/20
92/92 [==============================] - 7s 75ms/step - loss: 0.3527 - accuracy: 0.8849 - val_loss: 0.4601 - val_accuracy: 0.8338
Epoch 16/20
92/92 [==============================] - 7s 75ms/step - loss: 0.3483 - accuracy: 0.8869 - val_loss: 0.4510 - val_accuracy: 0.8447
Epoch 17/20
92/92 [==============================] - 7s 74ms/step - loss: 0.3424 - accuracy: 0.8941 - val_loss: 0.4484 - val_accuracy: 0.8420
Epoch 18/20
92/92 [==============================] - 7s 75ms/step - loss: 0.3375 - accuracy: 0.8910 - val_loss: 0.4537 - val_accuracy: 0.8420
Epoch 19/20
92/92 [==============================] - 7s 74ms/step - loss: 0.3334 - accuracy: 0.8910 - val_loss: 0.4454 - val_accuracy: 0.8420
Epoch 20/20
92/92 [==============================] - 7s 74ms/step - loss: 0.3285 - accuracy: 0.8978 - val_loss: 0.4570 - val_accuracy: 0.8420
In [112]:
# Training curves for the transfer-learning model: accuracy (left), loss (right).
history_dict = history.history
acc = history_dict['accuracy']
val_acc = history_dict['val_accuracy']
loss = history_dict['loss']
val_loss = history_dict['val_loss']

plt.figure(figsize=(8, 4))

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()), 1])
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0, 1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()
In [113]:
# Run the transfer-learning model on one test batch; keep the images and
# labels for the visualization cells below.
image_batch, label_batch = next(iter(test_batches.take(1)))
images = image_batch
labels = label_batch
predictions = model.predict(image_batch)

predictions
Out[113]:
array([[6.42978121e-03, 2.70292051e-02, 8.17796171e-01, 1.16137145e-02,
        1.37131110e-01],
       [9.99899745e-01, 9.32928451e-05, 1.76209383e-07, 1.08059965e-06,
        5.60810213e-06],
       [7.35793531e-01, 1.33265764e-01, 8.93671159e-03, 1.06362648e-01,
        1.56413317e-02],
       [5.68800606e-04, 2.09205667e-03, 4.89121795e-01, 1.24602811e-03,
        5.06971300e-01],
       [7.35460758e-01, 2.39669383e-01, 4.16448154e-03, 9.35971923e-03,
        1.13456594e-02],
       [9.62615311e-01, 2.72168033e-02, 9.54214076e-04, 8.90906341e-03,
        3.04642133e-04],
       [3.84999090e-03, 3.07303877e-03, 9.42384243e-01, 2.80811843e-02,
        2.26115268e-02],
       [4.22896683e-01, 5.49593687e-01, 1.45423881e-04, 2.67198104e-02,
        6.44439075e-04],
       [1.68472528e-03, 1.33166584e-04, 6.38590101e-03, 9.62124109e-01,
        2.96721160e-02],
       [9.99835372e-01, 1.60882817e-04, 8.66828529e-08, 3.25045517e-06,
        3.98369366e-07],
       [9.80435610e-01, 1.79450903e-02, 1.09952263e-04, 1.25222676e-03,
        2.57112959e-04],
       [4.25106645e-01, 3.16474997e-02, 1.14259079e-01, 4.25614655e-01,
        3.37213604e-03],
       [5.31466352e-03, 4.04791674e-03, 6.94700330e-02, 8.31501186e-01,
        8.96663144e-02],
       [1.25461514e-03, 3.29343267e-02, 2.94253469e-01, 3.87472287e-02,
        6.32810414e-01],
       [9.23143923e-01, 5.27782738e-02, 4.26190486e-03, 5.50384959e-03,
        1.43121257e-02],
       [3.83203551e-05, 1.24107522e-04, 1.93665177e-03, 9.97528851e-01,
        3.72065027e-04],
       [2.88821175e-04, 1.00069703e-03, 9.40206528e-01, 6.85656630e-03,
        5.16474098e-02],
       [9.72628176e-01, 2.60965489e-02, 7.06231367e-05, 1.16615673e-03,
        3.85350686e-05],
       [1.47220888e-03, 1.32834733e-01, 6.29679143e-01, 1.32483065e-01,
        1.03530891e-01],
       [4.71902192e-02, 1.18673993e-02, 2.64782667e-01, 1.16031868e-02,
        6.64556503e-01],
       [1.80815995e-01, 1.21097878e-01, 4.61784631e-01, 1.30861238e-01,
        1.05440244e-01],
       [9.93946016e-01, 4.68763663e-03, 1.33566864e-04, 1.01430784e-03,
        2.18506844e-04],
       [1.83873102e-02, 5.45701504e-01, 3.83173436e-01, 1.26130017e-03,
        5.14763854e-02],
       [9.98713493e-01, 7.06006249e-04, 5.15218489e-05, 5.04806347e-04,
        2.41497874e-05],
       [1.07614227e-01, 8.56315136e-01, 7.86418188e-03, 1.35372989e-02,
        1.46691101e-02],
       [1.79636721e-02, 5.79133593e-02, 1.43708110e-01, 6.92492306e-01,
        8.79225060e-02],
       [2.24142079e-03, 4.72155094e-01, 5.19774079e-01, 5.08255279e-03,
        7.46910519e-04],
       [1.14253256e-04, 4.43286088e-04, 2.74504013e-02, 1.16540715e-02,
        9.60337996e-01],
       [9.45156440e-04, 3.27201490e-03, 5.01334071e-02, 4.27089334e-02,
        9.02940512e-01],
       [1.77187115e-01, 4.26277816e-02, 3.56777869e-02, 7.11947262e-01,
        3.25601026e-02],
       [9.98899460e-01, 4.31945315e-04, 1.22347687e-04, 5.12194471e-04,
        3.40427359e-05],
       [5.85593842e-03, 4.16743290e-03, 6.86512515e-03, 5.25062643e-02,
        9.30605292e-01]], dtype=float32)
In [114]:
import numpy as np

# Reduce each row of class probabilities to the predicted class id.
predictions = predictions.argmax(axis=1)
predictions
Out[114]:
array([2, 0, 0, 4, 0, 0, 2, 1, 3, 0, 0, 3, 3, 4, 0, 3, 2, 0, 2, 4, 2, 0,
       1, 0, 1, 3, 2, 4, 4, 3, 0, 4])
In [115]:
# Visualize the test batch: blue titles mark correct predictions, red wrong ones.
plt.figure(figsize=(20, 12))

for plot_idx, (image, label, prediction) in enumerate(zip(images, labels, predictions), start=1):
    plt.subplot(4, 8, plot_idx)
    plt.imshow((image + 1) / 2)
    correct = label == prediction
    title = f'real: {label} / pred :{prediction}\n {correct}!'
    color = 'blue' if correct else 'red'
    plt.title(title, fontdict={'color': color})
    plt.axis('off')
In [116]:
# Accuracy on this test batch as a percentage.  The original loop iterated
# over the images without using them and hard-coded the batch size (32);
# derive the denominator from the predictions instead.
count = 0  # number of correct predictions
for label, prediction in zip(labels, predictions):
    if label == prediction:
        count = count + 1

print(count / len(predictions) * 100)  # around 95% over the full test set
84.375
In [ ]: